#' @title Classification Imbalanced Random Forest Src Learner
#' @author HarutyunyanLiana
#' @name mlr_learners_classif.imbalanced_rfsrc
#'
#' @description
#' Imbalanced random forest for binary (two-class) classification.
#' Calls [randomForestSRC::imbalanced.rfsrc()] from \CRANpkg{randomForestSRC}.
#'
#' @section Custom mlr3 parameters:
#' - `mtry`:
#'   - This hyperparameter can alternatively be set via the added hyperparameter `mtry.ratio`
#'     as `mtry = max(ceiling(mtry.ratio * n_features), 1)`.
#'     Note that `mtry` and `mtry.ratio` are mutually exclusive.
#' - `sampsize`:
#'   - This hyperparameter can alternatively be set via the added hyperparameter `sampsize.ratio`
#'     as `sampsize = max(ceiling(sampsize.ratio * n_obs), 1)`.
#'     Note that `sampsize` and `sampsize.ratio` are mutually exclusive.
#'
#' @section Custom mlr3 defaults:
#' - `cores`:
#'   - Actual default: Auto-detecting the number of cores
#'   - Adjusted default: 1
#'   - Reason for change: Threading conflicts with explicit parallelization via \CRANpkg{future}.
#'
#' @templateVar id classif.imbalanced_rfsrc
#' @template learner
#'
#' @references
#' `r format_bib("obrien2019imbrfsrc", "chen2004imbrf")`
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerClassifImbalancedRandomForestSRC = R6Class("LearnerClassifImbalancedRandomForestSRC",
  inherit = LearnerClassif,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ps(
        ntree = p_int(default = 3000, lower = 1L, tags = "train"),
        method = p_fct(
          default = "rfq",
          levels = c("rfq", "brf", "standard"),
          tags = "train"
        ),
        block.size = p_int(default = 10L, lower = 1L, tags = c("train", "predict")),
        fast = p_lgl(default = FALSE, tags = "train"),
        ratio = p_dbl(0, 1, tags = "train"),

        mtry = p_int(lower = 1L, tags = "train"),
        mtry.ratio = p_dbl(lower = 0, upper = 1, tags = "train"),
        nodesize = p_int(default = 15L, lower = 1L, tags = "train"),
        nodedepth = p_int(lower = 1L, tags = "train"),
        splitrule = p_fct(
          levels = c("gini", "auc", "entropy"),
          default = "gini",
          tags = "train"
        ),
        nsplit = p_int(lower = 0, default = 10, tags = "train"),
        importance = p_fct(
          default = "FALSE",
          levels = c("FALSE", "TRUE", "none", "permute", "random", "anti"),
          tags = c("train", "predict")
        ),
        bootstrap = p_fct(
          default = "by.root",
          levels = c("by.root", "by.node", "none", "by.user"),
          tags = "train"
        ),
        samptype = p_fct(
          default = "swor",
          levels = c("swor", "swr"),
          tags = "train"
        ),
        samp = p_uty(tags = "train"),
        membership = p_lgl(default = FALSE, tags = c("train", "predict")),
        sampsize = p_uty(tags = "train"),
        sampsize.ratio = p_dbl(0, 1, tags = "train"),
        na.action = p_fct(
          default = "na.omit",
          levels = c("na.omit", "na.impute"),
          tags = c("train", "predict")
        ),
        nimpute = p_int(default = 1L, lower = 1L, tags = "train"),
        ntime = p_int(lower = 1L, tags = "train"),
        cause = p_int(lower = 1L, tags = "train"),
        proximity = p_fct(
          default = "FALSE",
          levels = c("FALSE", "TRUE", "inbag", "oob", "all"),
          tags = c("train", "predict")
        ),
        distance = p_fct(
          default = "FALSE",
          levels = c("FALSE", "TRUE", "inbag", "oob", "all"),
          tags = c("train", "predict")
        ),
        forest.wt = p_fct(
          default = "FALSE",
          levels = c("FALSE", "TRUE", "inbag", "oob", "all"),
          tags = c("train", "predict")
        ),
        xvar.wt = p_uty(tags = "train"),
        split.wt = p_uty(tags = "train"),
        forest = p_lgl(default = TRUE, tags = "train"),
        var.used = p_fct(
          default = "FALSE",
          levels = c("FALSE", "all.trees", "by.tree"),
          tags = c("train", "predict")
        ),
        split.depth = p_fct(
          default = "FALSE",
          levels = c("FALSE", "all.trees", "by.tree"),
          tags = c("train", "predict")
        ),
        seed = p_int(upper = -1L, tags = c("train", "predict")),
        do.trace = p_lgl(default = FALSE, tags = c("train", "predict")),
        statistics = p_lgl(default = FALSE, tags = c("train", "predict")),
        get.tree = p_uty(tags = "predict"),
        outcome = p_fct(
          default = "train",
          levels = c("train", "test"),
          tags = "predict"
        ),
        ptn.count = p_int(default = 0L, lower = 0L, tags = "predict"),
        # Not passed as an argument: applied through the `rf.cores` option
        # in `.train()` / `.predict()`.
        cores = p_int(default = 1L, lower = 1L, tags = c("train", "predict", "threads")),
        save.memory = p_lgl(default = FALSE, tags = "train"),
        perf.type = p_fct(levels = c("gmean", "misclass", "brier", "none"), tags = "train") # nolint
      )

      super$initialize(
        id = "classif.imbalanced_rfsrc",
        packages = "randomForestSRC",
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        predict_types = c("response", "prob"),
        param_set = ps,
        properties = c("weights", "missings", "importance", "oob_error", "twoclass"),
        man = "mlr3extralearners::mlr_learners_classif.imbalanced_rfsrc",
        label = "Imbalanced Random Forest"
      )
    },
    #' @description
    #' The importance scores are extracted from the slot `importance`.
    #' Errors if the learner is untrained or was trained without importance.
    #' @return Named `numeric()`, sorted in decreasing order.
    importance = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      if (is.null(self$model$importance)) {
        stopf("Set 'importance' to one of: {'TRUE', 'permute', 'random', 'anti'}.")
      }

      # Only the first column of the importance matrix is reported.
      # NOTE(review): presumably the overall (all-class) score; confirm
      # against randomForestSRC's output layout.
      sort(self$model$importance[, 1], decreasing = TRUE)
    },
    #' @description
    #' Selected features are extracted from the model slot `var.used`.
    #' Only features with a positive split count are returned.
    #' Errors if the learner is untrained or was trained without `var.used`.
    #' @return `character()`.
    selected_features = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      used = self$model$var.used
      if (is.null(used)) {
        stopf("Set 'var.used' to one of: {'all.trees', 'by.tree'}.")
      }

      # `var.used = "by.tree"` records counts per tree; NOTE(review): assumed
      # to be a matrix with one (named) column per feature. Collapse it to
      # per-feature totals so both settings are treated the same way.
      if (is.matrix(used)) {
        used = colSums(used)
      }
      # Keep only features that were actually split on; `which()` keeps the
      # names and silently drops NA counts.
      names(which(used > 0))
    },

    #' @description
    #' OOB error extracted from the model slot `err.rate`.
    #' Errors if the learner is untrained.
    #' @return `numeric()`.
    oob_error = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      # Row `ntree` is the error after the final tree; column 1 is used as the
      # summary error.
      as.numeric(self$model$err.rate[self$model$ntree, 1])
    }
  ),
  private = list(
    .train = function(task) {
      pv = self$param_set$get_values(tags = "train")
      pv = convert_ratio(pv, "mtry", "mtry.ratio", length(task$feature_names))
      pv = convert_ratio(pv, "sampsize", "sampsize.ratio", task$nrow)

      # The thread count is applied via the `rf.cores` option (see `.opts`
      # below), so it is removed from the argument list instead of being
      # forwarded to `imbalanced.rfsrc()` as an unknown argument.
      cores = pv$cores %??% 1L
      pv$cores = NULL

      if ("weights" %in% task$properties) {
        pv$case.wt = as.numeric(task$weights$weight) # nolint
      }

      invoke(randomForestSRC::imbalanced.rfsrc,
        formula = task$formula(), data = data.table::setDF(task$data()),
        .args = pv, .opts = list(rf.cores = cores))
    },
    .predict = function(task) {
      # Features are reordered to match the training data before prediction.
      newdata = data.table::setDF(ordered_features(task, self))
      pars = self$param_set$get_values(tags = "predict")

      # As in `.train()`: `cores` goes into the `rf.cores` option only.
      cores = pars$cores %??% 1L
      pars$cores = NULL

      pred = invoke(predict,
        object = self$model,
        newdata = newdata,
        .args = pars,
        .opts = list(rf.cores = cores))

      if (self$predict_type == "response") {
        list(response = pred$class)
      } else {
        list(prob = pred$predicted)
      }
    }
  )
)

# Register the learner class in the package dictionary under the same id
# that `initialize()` passes to the base class.
.extralrns_dict$add(
  "classif.imbalanced_rfsrc",
  LearnerClassifImbalancedRandomForestSRC
)
